{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# !! This is an experimental notebook. !!\n",
    "\n",
    "## Proper integration with TensorFlow is available in H2O's DeepWater project."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# H2O TensorFlow Deep Learning Demo\n",
    "\n",
    "You can also follow this [YouTube video](https://www.youtube.com/watch?v=62TFK641gG8&feature=youtu.be).\n",
    "## Prerequisites\n",
    "1. Install TensorFlow from [https://www.tensorflow.org](https://www.tensorflow.org)\n",
    "2. Download Sparkling Water from [here](http://h2o-release.s3.amazonaws.com/sparkling-water/rel-2.0/latest.html)\n",
    "3. Follow [instructions to setup PySparkling](https://github.com/h2oai/sparkling-water/blob/master/py/README.rst) (especially steps 1 and 2)\n",
    "4. Launch a Jupyter Notebook that connects to PySparkling:\n",
    "```\n",
    "cd ~/Downloads\n",
    "unzip sparkling-water-2.0.0.zip\n",
    "cd sparkling-water-2.0.0\n",
    "PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS=\"notebook\" bin/pysparkling\n",
    "```\n",
    "5. Point the Notebook to [this file](https://raw.githubusercontent.com/h2oai/sparkling-water/master/py/examples/notebooks/TensorFlowDeepLearning.ipynb) (e.g., download it first, then upload into the Notebook)\n",
    "\n",
    "### Introduction\n",
    "In this tutorial, we'll build a simple deep artificial neural network with 2 hidden layers to classify handwritten digits from the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset. If you are not familiar with these terms, please check out our [Deep Learning Booklet](https://github.com/h2oai/h2o-3/blob/master/h2o-docs/src/booklets/v2_2015/PDFs/online/DeepLearning_Vignette.pdf)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Connect to H2O\n",
    "\n",
    "We connect to an H2O cluster (here: 3 nodes), and import the MNIST dataset (pre-split into 60k rows for training and 10k rows for testing). Each row has 28^2=784 grayscale pixel values from 0 to 255."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div style=\"overflow:auto\"><table style=\"width:50%\"><tr><td>H2O cluster uptime: </td>\n",
       "<td>6 seconds 696 milliseconds </td></tr>\n",
       "<tr><td>H2O cluster version: </td>\n",
       "<td>3.8.2.6</td></tr>\n",
       "<tr><td>H2O cluster name: </td>\n",
       "<td>sparkling-water-arno_-1896306748</td></tr>\n",
       "<tr><td>H2O cluster total nodes: </td>\n",
       "<td>3</td></tr>\n",
       "<tr><td>H2O cluster total free memory: </td>\n",
       "<td>2.88 GB</td></tr>\n",
       "<tr><td>H2O cluster total cores: </td>\n",
       "<td>24</td></tr>\n",
       "<tr><td>H2O cluster allowed cores: </td>\n",
       "<td>24</td></tr>\n",
       "<tr><td>H2O cluster healthy: </td>\n",
       "<td>True</td></tr>\n",
       "<tr><td>H2O Connection ip: </td>\n",
       "<td>172.16.2.20</td></tr>\n",
       "<tr><td>H2O Connection port: </td>\n",
       "<td>54327</td></tr>\n",
       "<tr><td>H2O Connection proxy: </td>\n",
       "<td>None</td></tr>\n",
       "<tr><td>Python Version: </td>\n",
       "<td>2.7.11</td></tr></table></div>"
      ],
      "text/plain": [
       "------------------------------  --------------------------------\n",
       "H2O cluster uptime:             6 seconds 696 milliseconds\n",
       "H2O cluster version:            3.8.2.6\n",
       "H2O cluster name:               sparkling-water-arno_-1896306748\n",
       "H2O cluster total nodes:        3\n",
       "H2O cluster total free memory:  2.88 GB\n",
       "H2O cluster total cores:        24\n",
       "H2O cluster allowed cores:      24\n",
       "H2O cluster healthy:            True\n",
       "H2O Connection ip:              172.16.2.20\n",
       "H2O Connection port:            54327\n",
       "H2O Connection proxy:\n",
       "Python Version:                 2.7.11\n",
       "------------------------------  --------------------------------"
      ]
     },
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H2OContext: ip=172.16.2.20, port=54327 (open UI at http://172.16.2.20:54327 )\n",
      "\n",
      "Parse Progress: [##################################################] 100%\n",
      "\n",
      "Parse Progress: [##################################################] 100%\n"
     ]
    }
   ],
   "source": [
    "## Read MNIST data into H2O\n",
    "import h2o\n",
    "from pysparkling import H2OContext\n",
    "h2o.__version__\n",
    "hc = H2OContext.getOrCreate(spark)\n",
    "print(hc)\n",
    "DATASET_DIR=\"http://s3.amazonaws.com/h2o-public-test-data/bigdata/laptop/mnist\"\n",
    "train_frame = h2o.import_file(\"{}/{}\".format(DATASET_DIR, \"train.csv.gz\"))\n",
    "test_frame = h2o.import_file(\"{}/{}\".format(DATASET_DIR, \"test.csv.gz\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's first confirm that TensorFlow is working properly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "## can simulate larger clusters here\n",
    "NODES=3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Sparkling, TensorFlow!', 'Sparkling, TensorFlow!', 'Sparkling, TensorFlow!']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow\n",
    "## Initialize TensorFlow session and test it\n",
    "def map_fun(i):\n",
    "  import tensorflow as tf\n",
    "  with tf.Graph().as_default() as g:\n",
    "    hello = tf.constant('Sparkling, TensorFlow!', name=\"hello_constant\")\n",
    "    with tf.Session() as sess:\n",
    "      return sess.run(hello)\n",
    "sc.parallelize(list(range(NODES)), NODES).map(map_fun).collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we expose the data in H2O through the Spark DataFrame API (no copy made), so that the Python process running TensorFlow can access the data from the PySpark(ling) context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "train_df = hc.as_spark_frame(train_frame).repartition(NODES)\n",
    "test_df = hc.as_spark_frame(test_frame).repartition(NODES)\n",
    "#train_df.printSchema()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a TensorFlow Deep Learning model\n",
    "Now, we define a TensorFlow Deep Learning model with 2 hidden layers of 50 neurons each and the Rectifier (ReLU) activation function. We use the Softmax function to turn the 10 output neuron activation values into 10 class probabilities. We initialize the weights and biases with Gaussian noise. We train the model with mini-batch gradient descent, using a fixed learning rate and no momentum."
   ]
  },
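  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the architecture described above concrete, here is a minimal NumPy sketch of the same forward pass (784 inputs, two ReLU hidden layers of 50 units, a 10-way softmax) and the summed cross-entropy cost. It trains nothing and is not part of the TensorFlow model; the helper names (`relu`, `softmax`, `forward`), the random inputs and the example labels are assumptions made for illustration only.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def relu(z):\n",
    "    # Rectifier activation: negative values are clipped to zero\n",
    "    return np.maximum(z, 0.0)\n",
    "\n",
    "def softmax(z):\n",
    "    # Subtract the row-wise maximum before exponentiating for numerical stability\n",
    "    e = np.exp(z - z.max(axis=1, keepdims=True))\n",
    "    return e / e.sum(axis=1, keepdims=True)\n",
    "\n",
    "def forward(x, W, b):\n",
    "    # Two hidden layers followed by a softmax output layer\n",
    "    h1 = relu(x.dot(W[0]) + b[0])\n",
    "    h2 = relu(h1.dot(W[1]) + b[1])\n",
    "    return softmax(h2.dot(W[2]) + b[2])\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "sizes = [784, 50, 50, 10]\n",
    "# Gaussian initialization with standard deviation 0.1\n",
    "W = [rng.normal(scale=0.1, size=(m, n)) for m, n in zip(sizes[:-1], sizes[1:])]\n",
    "b = [rng.normal(scale=0.1, size=n) for n in sizes[1:]]\n",
    "\n",
    "x = rng.rand(4, 784)               # 4 made-up images with pixel values in 0..1\n",
    "labels = np.eye(10)[[3, 1, 4, 1]]  # made-up one-hot labels\n",
    "probs = forward(x, W, b)\n",
    "cross_entropy = -np.sum(labels * np.log(probs))\n",
    "print(probs.shape)\n",
    "print(cross_entropy)\n",
    "```\n",
    "\n",
    "Each row of `probs` sums to 1, which is what lets us read the 10 outputs as class probabilities."
   ]
  },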
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "## Define the number of hidden neurons per layer\n",
    "HN=50\n",
    "\n",
    "# create_nn does the following:\n",
    "# - loads the local training data into a numpy array (from Spark -> Python)\n",
    "# - trains a TF Deep Learning model with 2 hidden layers\n",
    "# - prints the accuracy on a batch of the training data\n",
    "def create_nn(data_train, data_test, iterations, batch_size):\n",
    "    ## input\n",
    "    x = tf.placeholder(tf.float32, [None, 784])\n",
    "    ## weights\n",
    "    W = [tf.Variable(tf.random_normal([784,HN],stddev=0.1))\n",
    "        ,tf.Variable(tf.random_normal([HN, HN],stddev=0.1))\n",
    "        ,tf.Variable(tf.random_normal([HN, 10],stddev=0.1))]\n",
    "    ## biases\n",
    "    b = [tf.Variable(tf.random_normal([HN],    stddev=0.1))\n",
    "        ,tf.Variable(tf.random_normal([HN],    stddev=0.1))\n",
    "        ,tf.Variable(tf.random_normal([10],    stddev=0.1))]\n",
    "    ## hidden layer activation\n",
    "    h1 = tf.nn.relu(   tf.matmul(x,  W[0]) + b[0])\n",
    "    h2 = tf.nn.relu(   tf.matmul(h1, W[1]) + b[1])\n",
    "    ## output\n",
    "    y = tf.nn.softmax( tf.matmul(h2, W[2]) + b[2])\n",
    "    ## storage for actual labels\n",
    "    y_ = tf.placeholder(tf.float32, [None, 10])\n",
    "    ## cost function\n",
    "    cross_entropy = -tf.reduce_sum(y_*tf.log(y))                    \n",
    "    ## optimizer\n",
    "    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)\n",
    "    \n",
    "    # Train the model\n",
    "    init = tf.initialize_all_variables()\n",
    "    sess = tf.Session()\n",
    "    sess.run(init)\n",
    "    print(\"Training TensorFlow Deep Learning model\")\n",
    "    for i in range(iterations):\n",
    "      #print(\"TensorFlow iter: \", i, \" session: \", sess)\n",
    "      batch_xs, batch_ys = data_train.next_batch(batch_size)\n",
    "      sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})\n",
    "        \n",
    "    model = [(sess.run(W[0]),sess.run(W[1]),sess.run(W[2]),sess.run(b[0]),sess.run(b[1]),sess.run(b[2]))]\n",
    "\n",
    "    # Model evaluation\n",
    "    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))\n",
    "    batch_xs, batch_ys = data_test.next_batch(batch_size)\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "    print(\"Training Accuracy:\", sess.run(accuracy, feed_dict={x: batch_xs, y_: batch_ys}))\n",
    "    #print(sess.run(tf.argmax(y,1), feed_dict={x: batch_xs, y_: batch_ys}))\n",
    "    \n",
    "    sess.close()\n",
    "    return iter(model)\n",
    "    \n",
    "    # Export the model\n",
    "    #from tensorflow_serving.session_bundle import exporter\n",
    "    #export_path = \"/tmp/xxx/\"\n",
    "    #saver = tf.train.Saver(sharded=True)\n",
    "    #model_exporter = exporter.Exporter(saver)\n",
    "    #signature = exporter.classification_signature(input_tensor=x, scores_tensor=y)\n",
    "    #model_exporter.init(sess.graph.as_graph_def(), default_graph_signature=signature)\n",
    "    #model_exporter.export(export_path, tf.constant(FLAGS.export_version), sess)\n",
    "    \n",
    "## Internal Helpers\n",
    "\n",
    "# Sample without replacement to provide a batch of the requested size\n",
    "# Load everything into a numpy data structure\n",
    "import numpy as np\n",
    "\n",
    "def expand1hot(response, levels):\n",
    "    nrows = response.shape[0]\n",
    "    result = np.zeros((nrows, levels), dtype=np.float32)\n",
    "    result[np.arange(nrows), response.astype(np.int8)] = 1.0\n",
    "    return result\n",
    "\n",
    "class RowData:\n",
    "    def __init__(self, it):\n",
    "        self._part_array = np.array([ [a for a in x] for x in it], dtype=np.float32)\n",
    "        # Definition of input features\n",
    "        self._x = list(range(784))\n",
    "        # Index of response\n",
    "        self._y = 784\n",
    "\n",
    "    def next_batch(self, n):\n",
    "        # Sample from local data without replacement\n",
    "        dim = self._part_array.shape[0] # number of rows\n",
    "        sample = np.random.choice(dim, n, replace=False)\n",
    "        data = self._part_array[sample, :]\n",
    "        # Data coming from H2O, pixel values are 0..255 -> normalize to 0..1\n",
    "        # FIXME: this should be done via RDD or H2O API directly !\n",
    "        train = data[:, self._x]/255\n",
    "        response = expand1hot(data[:, self._y], 10)\n",
    "        return (train, response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run TensorFlow on each H2O/PySparkling node\n",
    "\n",
    "We use the Spark Map/Reduce paradigm to distribute the training across multiple worker nodes. Each node trains on its local data (stored in H2O, accessed by TensorFlow via JVM -> Python serialization provided by the PySpark(ling) API).\n",
    "\n",
    "Here, we train only for a short time for demo purposes. This is certainly not the best quality model we can build."
   ]
  },
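  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-partition pattern can be illustrated without Spark. The sketch below mimics the contract that `mapPartitions` relies on: the function is called once per partition with an iterator over that partition's rows, and it returns an iterator of results that are then concatenated. The names `per_partition` and `partitions` are hypothetical, and the summary tuple merely stands in for the trained weights and biases.\n",
    "\n",
    "```python\n",
    "def per_partition(rows):\n",
    "    # rows is an iterator over one partition's records; materialize it once\n",
    "    local = list(rows)\n",
    "    # Stand-in for training: return a single summary for this partition\n",
    "    return iter([(len(local), sum(local))])\n",
    "\n",
    "# Three made-up partitions, one per simulated node\n",
    "partitions = [[1, 2, 3], [4, 5], [6]]\n",
    "\n",
    "# Call the function once per partition and flatten the returned iterators\n",
    "collected = [out for part in partitions for out in per_partition(iter(part))]\n",
    "print(collected)\n",
    "```\n",
    "\n",
    "Returning exactly one item per partition is why we later get one set of coefficients per node."
   ]
  },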
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Number of batches to iterate\n",
    "ITERATIONS = 100\n",
    "# Batch size (per iteration)\n",
    "BATCH_SIZE = 100\n",
    "# Use MNIST dataset provided by TensorFlow - for debugging only\n",
    "USE_TF_MNIST=False\n",
    "\n",
    "def train_nn(iterations, batch_size, use_tf_mnist=False):\n",
    "    def perPartition(it):\n",
    "        if not use_tf_mnist:\n",
    "            train_data = RowData(it)\n",
    "            test_data = train_data\n",
    "        else:\n",
    "            from tensorflow.examples.tutorials.mnist import input_data\n",
    "            mnist = input_data.read_data_sets('MNIST_data', one_hot=True)\n",
    "            train_data = mnist.train\n",
    "            test_data = mnist.train\n",
    "            \n",
    "        return create_nn(train_data, test_data, iterations, batch_size)\n",
    "        \n",
    "    return perPartition\n",
    "\n",
    "coeffs_per_node = train_df.rdd.mapPartitions(train_nn(ITERATIONS, BATCH_SIZE, USE_TF_MNIST)).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n",
      "6\n"
     ]
    }
   ],
   "source": [
    "# Now, we have the weights and biases for each node\n",
    "print(len(coeffs_per_node))    ## Number of nodes\n",
    "print(len(coeffs_per_node[0])) ## Number of weight and bias arrays"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert the TensorFlow model into a H2O Deep Learning model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Parse Progress: [##################################################] 100%\n",
      "\n",
      "Parse Progress: [##################################################] 100%\n",
      "\n",
      "Parse Progress: [##################################################] 100%\n",
      "\n",
      "Parse Progress: [##################################################] 100%\n",
      "\n",
      "Parse Progress: [##################################################] 100%\n",
      "\n",
      "Parse Progress: [##################################################] 100%\n",
      "[[50, 784], [50, 50], [10, 50]]\n",
      "[[50, 1], [50, 1], [10, 1]]\n"
     ]
    }
   ],
   "source": [
    "# Average the weights and biases across all node-local models\n",
    "avg_coeffs = [c for c in coeffs_per_node[0]]\n",
    "for i in range(0,len(avg_coeffs)):\n",
    "    for node in range(1,NODES):\n",
    "        avg_coeffs[i] = avg_coeffs[i] + coeffs_per_node[node][i]\n",
    "avg_coeffs = [c/NODES for c in avg_coeffs]\n",
    "\n",
    "num_weights=int(len(coeffs_per_node[0])/2)\n",
    "\n",
    "## Convert the model coefficients (weights/biases) to H2O Frames\n",
    "H2O_w = [h2o.H2OFrame(np.transpose(c)) for c in avg_coeffs[0:num_weights]]\n",
    "H2O_b = [h2o.H2OFrame(np.transpose(np.matrix(c))) for c in avg_coeffs[num_weights:2*num_weights]]\n",
    "\n",
    "print([c.dim for c in H2O_w])\n",
    "print([c.dim for c in H2O_b])"
   ]
  },
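  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The averaging and layout conversion above can be checked on toy data. This is a small sketch under stated assumptions: `average_models` is a hypothetical helper, the toy shapes are made up, and the target layout (weights as rows = output units, biases as a single column) is inferred from the dimensions printed above rather than from the H2O documentation. Note that element-wise averaging of independently trained networks is a simple heuristic, not a guarantee of a better model.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def average_models(models):\n",
    "    # models: one list of arrays per node, all with matching shapes\n",
    "    n = float(len(models))\n",
    "    # zip(*models) groups the i-th array of every node together\n",
    "    return [sum(arrays) / n for arrays in zip(*models)]\n",
    "\n",
    "# Three toy node-local models, each holding one weight matrix (2 inputs x 3 units)\n",
    "# and one bias vector (3 units), filled with the constants 1, 2 and 3\n",
    "models = [[np.full((2, 3), float(k)), np.full(3, float(k))] for k in (1, 2, 3)]\n",
    "\n",
    "avg_W, avg_b = average_models(models)\n",
    "\n",
    "# Convert to the assumed H2O layout: (units_out, units_in) and a column of biases\n",
    "W_h2o = np.transpose(avg_W)\n",
    "b_h2o = avg_b.reshape(-1, 1)\n",
    "print(W_h2o.shape)\n",
    "print(b_h2o.shape)\n",
    "```\n",
    "\n",
    "Here every averaged entry equals 2.0, the mean of 1, 2 and 3."
   ]
  },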
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "deeplearning Model Build Progress: [                                                  ] 00%\n"
     ]
    }
   ],
   "source": [
    "#Initialize an H2O Model with those weights/biases\n",
    "from h2o.estimators.deeplearning import H2ODeepLearningEstimator\n",
    "\n",
    "## Create an H2O Deep Learning model from the TensorFlow model\n",
    "dlmodel = H2ODeepLearningEstimator(\n",
    "    model_id=\"model_from_TF\", ## we want to be able to find the model in Flow later\n",
    "    hidden=[HN,HN],           ## same Network layout as TF - two hidden layers\n",
    "    epochs=0,                 ## no training done in H2O - just copy over the model from TF\n",
    "    ignore_const_cols=False,  ## keep all input features (unless we also drop const cols in TF)\n",
    "    sparse=True,              ## faster as 0 input remains 0 -> sparse activation -> sparse updates\n",
    "    variable_importances=True\n",
    "    ### Initialize the H2O model with the TensorFlow model state\n",
    "    ### Requires H2O 3.8.2.1 or later\n",
    "    ,initial_weights=[H2O_w[0],H2O_w[1],H2O_w[2]]\n",
    "    ,initial_biases =[H2O_b[0],H2O_b[1],H2O_b[2]]\n",
    ")\n",
    "train_frame[784] = train_frame[784].asfactor()\n",
    "dlmodel.train(x=list(range(784)),y=784,training_frame=train_frame)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Score the H2O Deep Learning model in H2O (with the TensorFlow state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "ModelMetricsMultinomial: deeplearning\n",
      "** Reported on test data. **\n",
      "\n",
      "MSE: 0.778273808871\n",
      "R^2: 0.907184785182\n",
      "LogLoss: 2.18184795723\n",
      "\n",
      "Confusion Matrix: vertical: actual; across: predicted\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div style=\"overflow:auto\"><table style=\"width:50%\"><tr><td><b>0</b></td>\n",
       "<td><b>1</b></td>\n",
       "<td><b>2</b></td>\n",
       "<td><b>3</b></td>\n",
       "<td><b>4</b></td>\n",
       "<td><b>5</b></td>\n",
       "<td><b>6</b></td>\n",
       "<td><b>7</b></td>\n",
       "<td><b>8</b></td>\n",
       "<td><b>9</b></td>\n",
       "<td><b>Error</b></td>\n",
       "<td><b>Rate</b></td></tr>\n",
       "<tr><td>63.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>1.0</td>\n",
       "<td>132.0</td>\n",
       "<td>0.0</td>\n",
       "<td>784.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.9357143</td>\n",
       "<td>917 / 980</td></tr>\n",
       "<tr><td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>1.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>3.0</td>\n",
       "<td>104.0</td>\n",
       "<td>0.0</td>\n",
       "<td>1027.0</td>\n",
       "<td>0.0</td>\n",
       "<td>1.0</td>\n",
       "<td>1,135 / 1,135</td></tr>\n",
       "<tr><td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>102.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>831.0</td>\n",
       "<td>1.0</td>\n",
       "<td>95.0</td>\n",
       "<td>3.0</td>\n",
       "<td>0.9011628</td>\n",
       "<td>930 / 1,032</td></tr>\n",
       "<tr><td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>1.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>122.0</td>\n",
       "<td>1.0</td>\n",
       "<td>884.0</td>\n",
       "<td>2.0</td>\n",
       "<td>1.0</td>\n",
       "<td>1,010 / 1,010</td></tr>\n",
       "<tr><td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>179.0</td>\n",
       "<td>1.0</td>\n",
       "<td>682.0</td>\n",
       "<td>0.0</td>\n",
       "<td>107.0</td>\n",
       "<td>13.0</td>\n",
       "<td>0.8177189</td>\n",
       "<td>803 / 982</td></tr>\n",
       "<tr><td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>1.0</td>\n",
       "<td>25.0</td>\n",
       "<td>436.0</td>\n",
       "<td>0.0</td>\n",
       "<td>430.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.9719731</td>\n",
       "<td>867 / 892</td></tr>\n",
       "<tr><td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>3.0</td>\n",
       "<td>0.0</td>\n",
       "<td>1.0</td>\n",
       "<td>1.0</td>\n",
       "<td>930.0</td>\n",
       "<td>0.0</td>\n",
       "<td>21.0</td>\n",
       "<td>2.0</td>\n",
       "<td>0.0292276</td>\n",
       "<td>28 / 958</td></tr>\n",
       "<tr><td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>2.0</td>\n",
       "<td>1.0</td>\n",
       "<td>2.0</td>\n",
       "<td>0.0</td>\n",
       "<td>298.0</td>\n",
       "<td>1.0</td>\n",
       "<td>720.0</td>\n",
       "<td>4.0</td>\n",
       "<td>0.9990272</td>\n",
       "<td>1,027 / 1,028</td></tr>\n",
       "<tr><td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>1.0</td>\n",
       "<td>3.0</td>\n",
       "<td>827.0</td>\n",
       "<td>0.0</td>\n",
       "<td>142.0</td>\n",
       "<td>1.0</td>\n",
       "<td>0.8542094</td>\n",
       "<td>832 / 974</td></tr>\n",
       "<tr><td>0.0</td>\n",
       "<td>0.0</td>\n",
       "<td>1.0</td>\n",
       "<td>1.0</td>\n",
       "<td>11.0</td>\n",
       "<td>2.0</td>\n",
       "<td>816.0</td>\n",
       "<td>0.0</td>\n",
       "<td>174.0</td>\n",
       "<td>4.0</td>\n",
       "<td>0.9960357</td>\n",
       "<td>1,005 / 1,009</td></tr>\n",
       "<tr><td>63.0</td>\n",
       "<td>0.0</td>\n",
       "<td>110.0</td>\n",
       "<td>2.0</td>\n",
       "<td>195.0</td>\n",
       "<td>36.0</td>\n",
       "<td>5178.0</td>\n",
       "<td>3.0</td>\n",
       "<td>4384.0</td>\n",
       "<td>29.0</td>\n",
       "<td>0.8554</td>\n",
       "<td>8,554 / 10,000</td></tr></table></div>"
      ],
      "text/plain": [
       "0    1    2    3    4    5    6     7    8     9    Error      Rate\n",
       "---  ---  ---  ---  ---  ---  ----  ---  ----  ---  ---------  --------------\n",
       "63   0    0    0    0    1    132   0    784   0    0.935714   917 / 980\n",
       "0    0    1    0    0    3    104   0    1027  0    1          1,135 / 1,135\n",
       "0    0    102  0    0    0    831   1    95    3    0.901163   930 / 1,032\n",
       "0    0    1    0    0    0    122   1    884   2    1          1,010 / 1,010\n",
       "0    0    0    0    179  1    682   0    107   13   0.817719   803 / 982\n",
       "0    0    0    0    1    25   436   0    430   0    0.971973   867 / 892\n",
       "0    0    3    0    1    1    930   0    21    2    0.0292276  28 / 958\n",
       "0    0    2    1    2    0    298   1    720   4    0.999027   1,027 / 1,028\n",
       "0    0    0    0    1    3    827   0    142   1    0.854209   832 / 974\n",
       "0    0    1    1    11   2    816   0    174   4    0.996036   1,005 / 1,009\n",
       "63   0    110  2    195  36   5178  3    4384  29   0.8554     8,554 / 10,000"
      ]
     },
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Top-10 Hit Ratios: \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div style=\"overflow:auto\"><table style=\"width:50%\"><tr><td><b>k</b></td>\n",
       "<td><b>hit_ratio</b></td></tr>\n",
       "<tr><td>1</td>\n",
       "<td>0.1446</td></tr>\n",
       "<tr><td>2</td>\n",
       "<td>0.3175</td></tr>\n",
       "<tr><td>3</td>\n",
       "<td>0.5003</td></tr>\n",
       "<tr><td>4</td>\n",
       "<td>0.6101</td></tr>\n",
       "<tr><td>5</td>\n",
       "<td>0.6873000</td></tr>\n",
       "<tr><td>6</td>\n",
       "<td>0.733</td></tr>\n",
       "<tr><td>7</td>\n",
       "<td>0.8392000</td></tr>\n",
       "<tr><td>8</td>\n",
       "<td>0.9147</td></tr>\n",
       "<tr><td>9</td>\n",
       "<td>0.9926000</td></tr>\n",
       "<tr><td>10</td>\n",
       "<td>0.9999999</td></tr></table></div>"
      ],
      "text/plain": [
       "k    hit_ratio\n",
       "---  -----------\n",
       "1    0.1446\n",
       "2    0.3175\n",
       "3    0.5003\n",
       "4    0.6101\n",
       "5    0.6873\n",
       "6    0.733\n",
       "7    0.8392\n",
       "8    0.9147\n",
       "9    0.9926\n",
       "10   1"
      ]
     },
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## We can let H2O evaluate the performance of the TensorFlow model on the test set\n",
    "dlmodel.model_performance(test_frame)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8554"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Overall classification error of the TF model (in H2O form) on the test set - not very good yet - needs more training\n",
    "dlmodel.model_performance(test_frame).confusion_matrix()['Error'][-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract the Java scoring code (POJO) for the TensorFlow model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#dlmodel.download_pojo()  ## too large for Github"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "```\n",
    "Filepath: /model_from_TF.java\n",
    "/*\n",
    "  Licensed under the Apache License, Version 2.0\n",
    "    http://www.apache.org/licenses/LICENSE-2.0.html\n",
    "\n",
    "  AUTOGENERATED BY H2O at 2016-06-06T18:29:29.587-07:00\n",
    "  3.8.2.3\n",
    "  \n",
    "  Standalone prediction code with sample test data for DeepLearningModel named model_from_TF\n",
    "\n",
    "  How to download, compile and execute:\n",
    "      mkdir tmpdir\n",
    "      cd tmpdir\n",
    "      curl http://127.0.0.1:54321/3/h2o-genmodel.jar > h2o-genmodel.jar\n",
    "      curl http://127.0.0.1:54321/3/Models.java/model_from_TF > model_from_TF.java\n",
    "      javac -cp h2o-genmodel.jar -J-Xmx2g -J-XX:MaxPermSize=128m model_from_TF.java\n",
    "\n",
    "     (Note:  Try java argument -XX:+PrintCompilation to show runtime JIT compiler behavior.)\n",
    "*/\n",
    "import java.util.Map;\n",
    "import hex.genmodel.GenModel;\n",
    "import hex.genmodel.annotations.ModelPojo;\n",
    "...\n",
    "  // Pass in data in a double[], pre-aligned to the Model's requirements.\n",
    "  // Jam predictions into the preds[] array; preds[0] is reserved for the\n",
    "  // main prediction (class for classifiers or value for regression),\n",
    "  // and remaining columns hold a probability distribution for classifiers.\n",
    "  public final double[] score0( double[] data, double[] preds ) {\n",
    "    java.util.Arrays.fill(preds,0);\n",
    "    java.util.Arrays.fill(NUMS,0);\n",
    "    int i = 0, ncats = 0;\n",
    "    final int n = data.length;\n",
    "    for(; i<n; ++i) {\n",
    "      NUMS[i] = Double.isNaN(data[i]) ? 0 : (data[i] - NORMSUB.VALUES[i])*NORMMUL.VALUES[i];\n",
    "    }\n",
    "    java.util.Arrays.fill(ACTIVATION[0],0);\n",
    "    for (i=0; i<NUMS.length; ++i) {\n",
    "      ACTIVATION[0][CATOFFSETS[CATOFFSETS.length-1] + i] = Double.isNaN(NUMS[i]) ? 0 : NUMS[i];\n",
    "    }\n",
    "    for (i=1; i<ACTIVATION.length; ++i) {\n",
    "      java.util.Arrays.fill(ACTIVATION[i],0);\n",
    "      int cols = ACTIVATION[i-1].length;\n",
    "      int rows = ACTIVATION[i].length;\n",
    "      int extra=cols-cols%8;\n",
    "      int multiple = (cols/8)*8-1;\n",
    "      int idx = 0;\n",
    "      float[] a = WEIGHT[i];\n",
    "      double[] x = ACTIVATION[i-1];\n",
    "      double[] y = BIAS[i];\n",
    "      double[] res = ACTIVATION[i];\n",
    "      for (int row=0; row<rows; ++row) {\n",
    "        double psum0 = 0, psum1 = 0, psum2 = 0, psum3 = 0, psum4 = 0, psum5 = 0, psum6 = 0, psum7 = 0;\n",
    "        for (int col = 0; col < multiple; col += 8) {\n",
    "          int off = idx + col;\n",
    "          psum0 += a[off    ] * x[col    ];\n",
    "          psum1 += a[off + 1] * x[col + 1];\n",
    "          psum2 += a[off + 2] * x[col + 2];\n",
    "          psum3 += a[off + 3] * x[col + 3];\n",
    "          psum4 += a[off + 4] * x[col + 4];\n",
    "          psum5 += a[off + 5] * x[col + 5];\n",
    "          psum6 += a[off + 6] * x[col + 6];\n",
    "          psum7 += a[off + 7] * x[col + 7];\n",
    "        }\n",
    "        res[row] += psum0 + psum1 + psum2 + psum3;\n",
    "        res[row] += psum4 + psum5 + psum6 + psum7;\n",
    "        for (int col = extra; col < cols; col++)\n",
    "          res[row] += a[idx + col] * x[col];\n",
    "        res[row] += y[row];\n",
    "        idx += cols;\n",
    "      }\n",
    "      if (i<ACTIVATION.length-1) {\n",
    "        for (int r=0; r<ACTIVATION[i].length; ++r) {\n",
    "          ACTIVATION[i][r] = Math.max(0, ACTIVATION[i][r]);\n",
    "        }\n",
    "      }\n",
    "      if (i == ACTIVATION.length-1) {\n",
    "        double max = ACTIVATION[i][0];\n",
    "        for (int r=1; r<ACTIVATION[i].length; r++) {\n",
    "          if (ACTIVATION[i][r]>max) max = ACTIVATION[i][r];\n",
    "        }\n",
    "        double scale = 0;\n",
    "        for (int r=0; r<ACTIVATION[i].length; r++) {\n",
    "          ACTIVATION[i][r] = Math.exp(ACTIVATION[i][r] - max);\n",
    "          scale += ACTIVATION[i][r];\n",
    "        }\n",
    "        for (int r=0; r<ACTIVATION[i].length; r++) {\n",
    "          if (Double.isNaN(ACTIVATION[i][r]))\n",
    "            throw new RuntimeException(\"Numerical instability, predicted NaN.\");\n",
    "          ACTIVATION[i][r] /= scale;\n",
    "          preds[r+1] = ACTIVATION[i][r];\n",
    "        }\n",
    "      }\n",
    "    }\n",
    "    preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, 0.5);\n",
    "    return preds;\n",
    "  }\n",
    "}\n",
    "...\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Continue Training the Deep Learning model in H2O"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "deeplearning Model Build Progress: [##################################################] 100%\n"
     ]
    }
   ],
   "source": [
    "## Train in H2O for 1 more epoch (one full pass over the training data)\n",
    "dlmodel.epochs=1\n",
    "dlmodel.train(x=list(range(784)),y=784,training_frame=train_frame)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0726"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Check the classification error of the H2O model after a bit of training in H2O - much better!\n",
    "p=dlmodel.model_performance(test_frame)\n",
    "p.confusion_matrix()['Error'][-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inspect the model in Flow\n",
    "Since the model is now in H2O, we can inspect it from [Flow](http://localhost:54321). Run ```print(hc)``` to see the URL to connect to Flow.\n",
    "\n",
    "![alt text](./getCloud.png \"Cloud status\")\n",
    "\n",
    "For example, we can graphically inspect the variable importance or the confusion matrix. We can also score the model on the test set in Flow, continue training the model from this checkpoint, or inspect the Java scoring code (POJO). We highly recommend getting familiar with Flow if you aren't already.\n",
    "![alt text](./TF_model_in_Flow.png \"TF model in Flow\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2.0
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}